#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################
import functools
import torch
from . import function
from . import quant_ste


###################################################
# STE (straight-through estimator) wrappers around the rounding/ceiling autograd
# functions: forward applies the quantization op, while PropagateQuantTensorSTE
# causes the wrapped function's backward to be skipped so the incoming gradient
# is propagated through unchanged.
round_g = quant_ste.PropagateQuantTensorSTE(function.RoundG.apply)
round_sym_g = quant_ste.PropagateQuantTensorSTE(function.RoundSymG.apply)
round_up_g = quant_ste.PropagateQuantTensorSTE(function.RoundUpG.apply)
round2_g = quant_ste.PropagateQuantTensorSTE(function.Round2G.apply)
ceil_g = quant_ste.PropagateQuantTensorSTE(function.CeilG.apply)
ceil2_g = quant_ste.PropagateQuantTensorSTE(function.Ceil2G.apply)

# This line with PropagateQuantTensorSTE is optional: using PropagateQuantTensorSTE will cause
# backward method of QuantizeDequantizeG to be skipped
# Replace with PropagateQuantTensorQTE to: allow gradient to flow back through the backward method
# Note: QTE here has effect only if QTE is used in forward() QuantTrainPAct2 in quant_train_modules.py
# by using quantize_backward_type = 'qte' in QuantTrainPAct2
# TODO: when using QTE here, we need to register this OP for ONNX export to work
# and even then the exported model may not be clean.
#quantize_dequantize_g = quant_ste.PropagateQuantTensorSTE(function.TorchQuantizeDequantizeG.apply)
quantize_dequantize_g = quant_ste.PropagateQuantTensorSTE(function.QuantizeDequantizeG.apply)


###################################################
def clamp_g(x, min, max, training, inplace=False, requires_grad=False):
    """Clamp x to the range [min, max], with a graph-friendly eval path.

    x: input tensor (passed through unchanged if None).
    min, max: clip bounds (scalars; converted to tensors on the grad path).
    training: selects the training-time vs eval-time implementation.
    inplace: use the in-place torch.clamp_ variant where applicable.
    requires_grad: when training, route through tensor min/max so gradients
        can flow to the clip bounds.
    Returns the clamped tensor (or None).
    """
    if x is None:
        return x
    #
    if not training:
        # eval mode: bake the bounds in as float constants, so the exported
        # (e.g. onnx) graph is simpler and has fixed constants.
        return torch.clamp_(x, float(min), float(max)) if inplace else torch.clamp(x, float(min), float(max))
    #
    if not requires_grad:
        # clamp takes less memory - using it for now
        return torch.clamp_(x, min, max) if inplace else torch.clamp(x, min, max)
    #
    # torch's clamp doesn't currently work with min and max as tensors,
    # so emulate it with elementwise min/max against broadcast bounds.
    # TODO: replace with torch.clamp when it supports tensor arguments
    # TODO: switch back to min/max if you want to learn the clip values by backprop
    zeros_scalar = torch.zeros_like(x.view(-1)[0])
    lower = zeros_scalar + min
    upper = zeros_scalar + max
    return torch.max(torch.min(x, upper), lower)


###################################################
# from torchvision shufflenetv2
# https://github.com/pytorch/vision/blob/master/torchvision/models/shufflenetv2.py
def channel_shuffle(x, groups):
    """Interleave the channels of x across groups (ShuffleNetV2-style).

    x: 4-D tensor of shape (batch, channels, height, width); channels must be
        divisible by groups.
    groups: number of groups to shuffle channels across.
    Returns a tensor of the same shape with channels permuted so that
    consecutive output channels come from different input groups.
    """
    # use x.size() directly - x.data is deprecated and bypasses autograd tracking
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape: split the channel dim into (groups, channels_per_group)
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # swap the group and per-group dims; contiguous() is needed before view
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten back to (batch, channels, height, width)
    x = x.view(batchsize, -1, height, width)

    return x


###################################################
def channel_split_by_chunks(x, chunks):
    """Split x into `chunks` pieces along the channel dim (dim=1)."""
    return torch.chunk(x, chunks, dim=1)


def channel_split_by_size(x, split_size):
    """Split x along the channel dim (dim=1) into pieces of size split_size
    (an int, or a list of per-piece sizes as accepted by torch.split)."""
    return torch.split(x, split_size, dim=1)


def channel_split_by_index(x, split_index):
    """Split x along the channel dim (dim=1) at split_index, returning the
    first split_index channels and the remaining channels as two tensors."""
    num_channels = int(x.size(1))
    sizes = [split_index, num_channels - split_index]
    return torch.split(x, sizes, dim=1)


def channel_slice_by_index(x, split_index):
    """Slice x along the channel dim (dim=1) at split_index.

    Same result as channel_split_by_index, but implemented with plain
    indexing; the returned tensors are views of x."""
    head = x[:, :split_index, ...]
    tail = x[:, split_index:, ...]
    return (head, tail)


###################################################
def crop_grid(features, crop_offsets, crop_size):
    """Crop a (crop_size[0] x crop_size[1]) window from each feature map via
    an affine grid + grid_sample (differentiable w.r.t. crop_offsets).

    features: 4-D tensor (batch, channels, height, width).
    crop_offsets: per-batch (y, x) pixel offsets of the crop; indexed as
        crop_offsets[:, 0] / crop_offsets[:, 1] and concatenated with (b, 1)
        tensors, so presumably shaped (b, 2, 1) — TODO confirm with callers.
    crop_size: (out_height, out_width) of the sampled output.
    Returns the sampled crops, shape (batch, channels, crop_size[0], crop_size[1]).
    """
    b = int(features.size(0))
    h = int(features.size(2))
    w = int(features.size(3))
    # new_zeros is the non-deprecated equivalent of features.new(b, 1).zero_()
    zero = features.new_zeros((b, 1))
    one = zero + 1.0
    # normalize pixel offsets into the [-1, 1] grid coordinate convention
    y_offset = crop_offsets[:, 0] / (h - 1)
    x_offset = crop_offsets[:, 1] / (w - 1)
    # per-batch affine transform: identity scale with an (x, y) translation
    theta = torch.cat([one, zero, x_offset, zero, one, y_offset], dim=1).view(-1, 2, 3)
    # bugfix: was affine_gird (typo) - that raised AttributeError at runtime
    grid = torch.nn.functional.affine_grid(theta, (b, 1, crop_size[0], crop_size[1]))
    out = torch.nn.functional.grid_sample(features, grid)
    return out


###################################################
def split_output_channels(output, output_channels):
    """Split a merged multi-task output tensor along its channel dimension.

    output: tensor of shape (C, H, W) or (B, C, H, W), or an already-split
        list/tuple which is returned unchanged.
    output_channels: per-task channel counts, taken consecutively from
        channel 0 (trailing channels beyond their sum are dropped).
    Returns a list with one tensor per entry in output_channels.
    """
    if isinstance(output, (list, tuple)):
        return output
    if len(output_channels) == 1:
        return [output]
    # consecutive [begin, end) channel ranges for each task
    boundaries = [0]
    for num_ch in output_channels:
        boundaries.append(boundaries[-1] + num_ch)
    #
    task_outputs = []
    for begin, end in zip(boundaries[:-1], boundaries[1:]):
        if len(output.shape) == 3:
            # channels-first without batch dim: slice along dim 0
            task_outputs.append(output[begin:end, ...])
        elif len(output.shape) == 4:
            # batched: slice along the channel dim
            task_outputs.append(output[:, begin:end, ...])
        else:
            assert False, 'incorrect dimensions'
        # --
    # --
    return task_outputs
